package main

import (
	"encoding/json"
	"fmt"
	"math/rand"
	"time"

	"github.com/go-redis/redis"

	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mysql"
)

// Thread is one thread record. It is stored through gorm, with ThreadID as
// the primary key (see the gorm tag); GetThreads queries it via the
// thread_id column. It is also cached in Redis as JSON (see RedisValue).
// Because the struct has no json tags, the JSON object uses the Go field
// names ThreadID, Text and CTime.
// NOTE(review): CTime is presumably the creation time — confirm.
type Thread struct {
	ThreadID int32 `gorm:"primary_key"`
	Text     string
	CTime    time.Time
}

// RedisKey returns the Redis key that caches t, formatted as
// "thread:<ThreadID>".
func (t *Thread) RedisKey() string {
	key := fmt.Sprintf("thread:%d", t.ThreadID)
	return key
}

// RedisValue returns the JSON encoding of t, which is what gets stored
// under RedisKey.
//
// The signature has no error result, so a marshalling failure is dropped
// and nil is returned. For Thread this can only happen when CTime's year
// falls outside [0, 9999]. GetThreads treats undecodable cache values as
// misses, so such an entry would presumably just be reloaded from the
// database on the next read.
func (t *Thread) RedisValue() []byte {
	encoded, err := json.Marshal(t)
	if err != nil {
		return nil
	}
	return encoded
}

// RedisTTL is how long a cached Thread lives in Redis: one day. GetThreads
// re-applies it to an entry each time that entry is served from the cache.
func (t *Thread) RedisTTL() time.Duration {
	return 24 * time.Hour
}

// RedisFailTTL is the lifetime of the "{}" placeholder that GetThreads
// caches for IDs missing from the database. It is kept short so a
// newly-inserted row becomes visible soon.
func (t *Thread) RedisFailTTL() time.Duration {
	const failTTL = 30 * time.Second
	return failTTL
}

// AddThread inserts thread into the database with gorm's Create, then
// caches it in Redis under thread.RedisKey() for thread.RedisTTL().
//
// The database error, if any, is returned before the cache is touched.
// Otherwise the result of the Redis SET is returned. That means a non-nil
// error can be reported even though the row was already inserted.
func AddThread(o *gorm.DB, c *redis.Client, thread *Thread) error {
	if dbErr := o.Create(thread).Error; dbErr != nil {
		return dbErr
	}
	return c.Set(thread.RedisKey(), thread.RedisValue(), thread.RedisTTL()).Err()
}

// SetThread persists thread with gorm's Save, then overwrites its Redis
// entry with the new value for thread.RedisTTL().
//
// The database error, if any, is returned without touching the cache.
// Otherwise the result of the Redis SET is returned. If that SET fails,
// the row is already saved and any previously cached value is left in place.
func SetThread(o *gorm.DB, c *redis.Client, thread *Thread) error {
	if dbErr := o.Save(thread).Error; dbErr != nil {
		return dbErr
	}
	return c.Set(thread.RedisKey(), thread.RedisValue(), thread.RedisTTL()).Err()
}

// GetThreads returns one *Thread per entry of ids, in the same order. It
// uses Redis as a read-through cache in front of the database.
//
// The lookup has three stages:
//  1. MGET every key. Values that decode as JSON are hits. Real entries get
//     their TTL reset to RedisTTL. Negative "{}" placeholders keep their
//     original short RedisFailTTL expiry.
//  2. The remaining IDs are loaded with a single IN query and written back
//     to the cache.
//  3. IDs that are missing from the database too get a "{}" placeholder,
//     cached for RedisFailTTL. This stops repeated lookups of nonexistent
//     IDs from passing through to the database.
//
// An ID that exists nowhere yields a *Thread with only ThreadID set. An
// empty ids returns an empty slice without contacting Redis. Errors from
// the Redis MGET or from the database are returned. Cache writes (TTL
// refreshes, fills, placeholders) are best-effort and their errors are
// ignored.
func GetThreads(o *gorm.DB, c *redis.Client, ids []int32) ([]*Thread, error) {
	// Cached for IDs the database does not have. It decodes successfully into
	// a Thread, so later lookups count it as a hit instead of querying again.
	const missPlaceholder = "{}"

	threads := make([]*Thread, len(ids))
	keys := make([]string, len(ids))
	for i, id := range ids {
		threads[i] = &Thread{ThreadID: id}
		keys[i] = threads[i].RedisKey()
	}
	if len(ids) == 0 {
		// Redis rejects MGET without keys, and there is nothing to fetch.
		return threads, nil
	}

	// Stage 1: load from cache.
	vals, err := c.MGet(keys...).Result()
	if err != nil {
		return nil, err
	}
	var missIDs []int32
	refresh := c.Pipeline()
	for i, v := range vals {
		s, ok := v.(string)
		if !ok {
			// MGET yields nil for keys that are absent or expired.
			missIDs = append(missIDs, ids[i])
			continue
		}
		// Decode into a fresh value so a corrupt entry cannot leave a
		// half-populated Thread in the result; it is reloaded instead.
		cached := &Thread{ThreadID: ids[i]}
		if err := json.Unmarshal([]byte(s), cached); err != nil {
			missIDs = append(missIDs, ids[i])
			continue
		}
		threads[i] = cached
		if s != missPlaceholder {
			// Refresh the TTL of real entries only. Refreshing a placeholder
			// would stretch its RedisFailTTL to the full RedisTTL and keep a
			// nonexistent ID negatively cached for a whole day.
			refresh.Expire(keys[i], cached.RedisTTL())
		}
	}
	// Best-effort: a failed refresh only means entries expire sooner.
	_, _ = refresh.Exec() // TODO: warning log

	if len(missIDs) == 0 {
		return threads, nil
	}
	fmt.Println("Cache Miss:", missIDs)

	// Stage 2: load misses from the database and write them back.
	loaded := make([]*Thread, 0, len(missIDs))
	if err := o.Where("thread_id IN (?)", missIDs).Find(&loaded).Error; err != nil {
		return nil, err
	}
	found := make(map[int32]*Thread, len(loaded))
	fill := c.Pipeline()
	for _, t := range loaded {
		found[t.ThreadID] = t
		fill.Set(t.RedisKey(), t.RedisValue(), t.RedisTTL())
	}

	// Stage 3: negative-cache IDs the database doesn't have either
	// (avoid pass-through).
	for _, id := range missIDs {
		if _, ok := found[id]; !ok {
			t := &Thread{ThreadID: id}
			fill.Set(t.RedisKey(), missPlaceholder, t.RedisFailTTL())
		}
	}
	// Best-effort, like the refresh above: the data is already loaded, so a
	// cache write failure must not turn a successful read into an error.
	_, _ = fill.Exec() // TODO: warning log

	for i, id := range ids {
		if t, ok := found[id]; ok {
			threads[i] = t
		}
	}
	return threads, nil
}

// init configures the redis package logger and prints the redis.Nil
// sentinel error as debug output.
func init() {
	// NOTE(review): passing nil presumably disables go-redis's internal
	// logging — confirm against the library version in use.
	redis.SetLogger(nil)
	sentinel := redis.Nil
	fmt.Println("redis.Nil", sentinel)
}

// main demonstrates the cache helpers against a local MySQL and Redis. It
// migrates the Thread table, then fetches the IDs 1, 2, -1 and 3 three
// times and prints each result:
//   - with the cache as found,
//   - after deleting thread 1's cache entry,
//   - after updating thread 1 through SetThread.
//
// NOTE(review): -1 is presumably an ID with no row, used to exercise the
// negative cache.
// Any error panics; the deferred Close calls still run.
func main() {
	o, err := gorm.Open("mysql", "root:nikki@(127.0.0.1:3306)/student?timeout=5s&parseTime=true&loc=Local&charset=utf8")
	if err != nil {
		panic(err)
	}
	defer o.Close()
	if err := o.AutoMigrate(&Thread{}).Error; err != nil {
		panic(err)
	}

	c := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
	})
	defer c.Close()

	rand.Seed(time.Now().UnixNano())

	//AddThread(o, c, &Thread{Text: "first", CTime: time.Now()})
	//AddThread(o, c, &Thread{Text: "second", CTime: time.Now().Add(time.Second)})
	//AddThread(o, c, &Thread{Text: "third", CTime: time.Now().Add(time.Second)})

	ids := []int32{1, 2, -1, 3}
	// showThreads fetches ids through the cache and prints ID and text of each.
	showThreads := func() {
		threads, err := GetThreads(o, c, ids)
		if err != nil {
			panic(err)
		}
		for _, thread := range threads {
			fmt.Println(thread.ThreadID, thread.Text)
		}
	}

	showThreads()

	// Delete Cache
	if err := c.Del((&Thread{ThreadID: 1}).RedisKey()).Err(); err != nil {
		panic(err)
	}
	fmt.Println("\nAfter Delete Cache:")
	showThreads()

	// Modify Cache. The error was previously dropped, which would hide a
	// failed update behind a stale read.
	updated := &Thread{ThreadID: 1, Text: "first" + fmt.Sprintf("_%d", rand.Intn(100)), CTime: time.Now()}
	if err := SetThread(o, c, updated); err != nil {
		panic(err)
	}
	fmt.Println("\nAfter Modify Cache:")
	showThreads()
}
